import os
import torch
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm
import pickle
import random
import re, string
import torch.nn.functional as F

# Runtime configuration, overridable through environment variables.
# A leading "~" is expanded here because neither open() nor torch.save()
# expand it themselves; without this the defaults resolve to a literal "./~/"
# directory. Note: despite its name, DATA_DIR points at a single data file.
DATA_DIR = os.path.expanduser(
    os.environ.get('DATA_DIR', '~/dual-map/google-gemma-2-2b/large_embedding_data.pt')
)
print(f"DATA_DIR: {DATA_DIR}")
MODEL_SAVE_PATH = os.path.expanduser(
    os.environ.get('MODEL_SAVE_PATH', '~/dual-map/model/google-gemma-2-2b/dual_map_mlp_model.pt')
)
print(f"MODEL_SAVE_PATH: {MODEL_SAVE_PATH}")
GEMMA_MODEL_NAME = os.environ.get('GEMMA_MODEL_NAME', "google/gemma-2-2b")
print(f"GEMMA_MODEL_NAME: {GEMMA_MODEL_NAME}")

########################################################
# Training related starts here
########################################################

class MLP(nn.Module):
    """Feed-forward regressor: 4 linear layers with GELU between them.

    The three hidden layers all have width ``2 * hidden_dim``. Inputs are
    divided by ``sqrt(input_dim)`` before the first layer (presumably to keep
    activations at a comparable scale regardless of input size — confirm).

    Args:
        input_dim: Size of the last axis of the input tensor.
        output_dim: Size of the last axis of the output tensor.
        hidden_dim: Nominal hidden size; the actual hidden width is doubled.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=1024):
        super().__init__()
        self.input_dim = input_dim
        width = 2 * hidden_dim  # hidden layers are twice the nominal size

        # (fan_in, fan_out) for each Linear that is followed by a GELU.
        # Linears are created in the same order as a hand-written Sequential
        # would, so parameter initialisation and state_dict keys match.
        sizes = [input_dim, width, width, width]
        stack = []
        for fan_in, fan_out in zip(sizes, sizes[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.GELU()]
        stack.append(nn.Linear(width, output_dim))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        scale = np.sqrt(self.input_dim)
        return self.layers(x / scale)

def load_data(filepath):
    """Load the collected data pairs from ``filepath``.

    The file is first read with :mod:`pickle` (the original storage format).
    If pickle rejects it and the file turns out to be a zip archive — the
    format ``torch.save`` writes, which the default ``.pt`` filename hints at —
    it is loaded with ``torch.load`` onto the CPU instead. Files that pickle can
    read are handled exactly as before.

    A leading ``~`` in ``filepath`` is expanded to the user's home directory,
    since ``open()`` does not do that on its own.

    SECURITY: both ``pickle.load`` and ``torch.load`` can execute arbitrary
    code while deserialising. Only load files from trusted sources.

    Args:
        filepath: Path to the serialized data pairs.

    Returns:
        The deserialized object (expected to be a sequence of pairs).

    Raises:
        FileNotFoundError: if the file does not exist.
        pickle.UnpicklingError: if the file is neither a valid pickle nor a
            zip archive.
    """
    import zipfile

    path = os.path.expanduser(filepath)
    try:
        with open(path, 'rb') as f:
            data_pairs = pickle.load(f)
    except pickle.UnpicklingError:
        # torch.save zip archives fail here with a "persistent id" error;
        # anything else is a genuinely corrupt/unknown file, so re-raise.
        if not zipfile.is_zipfile(path):
            raise
        data_pairs = torch.load(path, map_location='cpu')
    print(f"Loaded {len(data_pairs)} data pairs from {path}")
    return data_pairs

class EmbeddingDataset(Dataset):
    """Map-style dataset wrapping an indexable sequence of ``(x, y)`` pairs.

    The sequence is kept by reference; no copying or conversion happens.
    """

    def __init__(self, data_pairs):
        self.data_pairs = data_pairs

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        # Unpacking into two names enforces that each item is exactly a pair
        # and always hands back a tuple, whatever container the pair was in.
        pair = self.data_pairs[idx]
        source, target = pair
        return (source, target)

def prepare_data(data_pairs, test_ratio=0.05, seed=None):
    """Randomly split ``data_pairs`` into training and test subsets.

    Args:
        data_pairs: Sequence of ``(x, y)`` pairs.
        test_ratio: Fraction of samples put in the test split. The test size
            is ``int(test_ratio * len(data_pairs))`` (rounded down); all
            remaining samples go to training.
        seed: Optional integer seed for the split. ``None`` (default) draws
            from torch's global RNG as before. Pass a value to get the same
            split on every run.

    Returns:
        ``(train_dataset, test_dataset)`` as returned by ``random_split``.

    Raises:
        ValueError: if ``test_ratio`` is outside ``[0, 1]``. Such values
            would otherwise produce a negative split size.
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio}")

    dataset = EmbeddingDataset(data_pairs)

    # Calculate split sizes; they always sum to len(dataset).
    test_size = int(test_ratio * len(dataset))
    train_size = len(dataset) - test_size

    # Only pass a generator when a seed was requested, so the default call
    # keeps consuming torch's global RNG exactly as before.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_dataset, test_dataset = random_split(
        dataset, [train_size, test_size], **split_kwargs
    )

    print(f"Training set: {len(train_dataset)} samples")
    print(f"Test set: {len(test_dataset)} samples")

    return train_dataset, test_dataset

def _train_one_epoch(model, loader, loss_fn, optimizer, device, desc):
    """Run one optimisation pass over ``loader``; return the mean per-sample loss, or None if empty."""
    model.train()
    total, count = 0.0, 0
    for x_batch, y_batch in tqdm(loader, desc=desc):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        loss = loss_fn(model(x_batch), y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss.item() is a per-batch mean; weight it by batch size so the
        # epoch average is exact even when the last batch is smaller.
        total += loss.item() * x_batch.size(0)
        count += x_batch.size(0)
    return total / count if count else None


def _evaluate(model, loader, loss_fn, device):
    """Return the mean per-sample loss of ``model`` over ``loader`` without gradients, or None if empty."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for x_batch, y_batch in loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            loss = loss_fn(model(x_batch), y_batch)
            total += loss.item() * x_batch.size(0)
            count += x_batch.size(0)
    return total / count if count else None


def _save_checkpoint(model, path):
    """Write ``model.state_dict()`` to ``path``, creating missing parent directories."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")


def train_model(train_dataset, test_dataset, input_dim, output_dim,
                batch_size=128, lr=1e-4, epochs=53, save_path=None):
    """Train an ``MLP`` to regress ``y`` from ``x`` with MSE loss and AdamW.

    After every epoch the model is evaluated on ``test_dataset``. Whenever the
    validation loss beats the best seen so far, the model's ``state_dict`` is
    written to ``save_path``. If the test set is empty there is no loss to
    select on, so the latest weights are written after every epoch instead.

    Args:
        train_dataset: Dataset yielding ``(x, y)`` tensor pairs for training.
        test_dataset: Dataset yielding ``(x, y)`` pairs for validation.
        input_dim: Feature size of ``x``.
        output_dim: Feature size of ``y``.
        batch_size: Mini-batch size for both loaders.
        lr: AdamW learning rate.
        epochs: Number of passes over the training data.
        save_path: Checkpoint path; defaults to ``MODEL_SAVE_PATH``. A leading
            ``~`` is expanded and missing parent directories are created.

    Returns:
        The model after the final epoch. This is not necessarily the
        best-validation checkpoint written to disk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    save_path = os.path.expanduser(MODEL_SAVE_PATH if save_path is None else save_path)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    model = MLP(input_dim, output_dim).to(device)
    mse_loss = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    def fmt(value):
        return "n/a" if value is None else f"{value:.6f}"

    best_val_loss = float('inf')
    for epoch in range(epochs):
        avg_train_loss = _train_one_epoch(
            model, train_loader, mse_loss, optimizer, device,
            desc=f"Epoch {epoch+1}/{epochs}",
        )
        avg_val_loss = _evaluate(model, test_loader, mse_loss, device)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {fmt(avg_train_loss)}, Val Loss: {fmt(avg_val_loss)}")

        if avg_val_loss is None:
            # No validation data: keep the most recent weights on disk.
            _save_checkpoint(model, save_path)
        elif avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            _save_checkpoint(model, save_path)

    print("Training complete!")
    return model

########################################################
# Training ends here
########################################################

# Example usage
if __name__ == "__main__":
    # Load the (x, y) pairs and infer layer sizes from the first pair.
    data_pairs = load_data(DATA_DIR)
    if len(data_pairs) == 0:
        raise SystemExit(f"No data pairs found in {DATA_DIR}; nothing to train on.")
    print(f"Dataset loaded successfully. It contains {len(data_pairs)} pairs.")
    print(f"Example of the first pair: {data_pairs[0]}")
    input_dim = data_pairs[0][0].shape[0]
    output_dim = data_pairs[0][1].shape[0]
    print(f"Input dimension: {input_dim}, Output dimension: {output_dim}")

    # Prepare data
    train_dataset, test_dataset = prepare_data(data_pairs)

    # Train the model; the best-validation checkpoint is written to disk
    # by train_model itself.
    train_model(train_dataset, test_dataset, input_dim, output_dim)
    print("Training complete!\n\n")

    # CUDA housekeeping only applies when a GPU is present.
    # reset_peak_memory_stats() needs an initialised CUDA device and raises
    # on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached, unused blocks
        torch.cuda.reset_peak_memory_stats()  # reset peak-memory counters
        print(torch.cuda.memory_summary(abbreviated=True))  # current usage summary